pub mod attribute;

use std::{any::Any, rc::Rc};

use impl_downcast::impl_downcast;

use crate::{AsBytes, Context};

use self::attribute::Uniform;

/// A single bindable shader resource owned by a [`Material`].
///
/// Each attribute occupies one binding slot of the material's bind group; the slot
/// index is the attribute's position in the order it was added to the builder.
pub trait Attribute: Any {
    /// Binding type used for this attribute's entry when a bind group layout is
    /// generated by [`MaterialBuilder::build`].
    fn ty(&self) -> wgpu::BindingType;
    /// Resource placed into this attribute's bind group entry when the material's
    /// bind group is created.
    fn res(&self) -> wgpu::BindingResource;
    /// Hook invoked for every attribute by [`Material::prepare`].
    ///
    /// NOTE(review): presumably used to push pending CPU-side changes to the GPU
    /// through `ctx` — confirm against the implementors.
    fn prepare(&mut self, ctx: &Context);
}

// Presumably generates the `downcast` / `downcast_mut` helpers on `dyn Attribute`
// that `Material::attr` and `Material::attr_mut` call — confirm against the
// `impl_downcast` crate.
impl_downcast!(Attribute);

/// A bind group together with the attributes that back its bindings.
pub struct Material {
    /// Layout the bind group was created against; reference-counted so one layout
    /// can be handed to several materials.
    pub layout: Rc<wgpu::BindGroupLayout>,
    /// Bind group created from the attributes' resources.
    pub group: wgpu::BindGroup,
    /// Attributes in binding order: the element at index `i` is bound at binding `i`.
    pub attributes: Vec<Box<dyn Attribute>>,
}

impl Material {
    /// Returns a [`MaterialBuilder`] with no attributes registered yet.
    pub fn builder() -> MaterialBuilder {
        MaterialBuilder {
            attributes: Vec::new(),
        }
    }

    /// Borrows the attribute bound at `bind_index`, viewed as the concrete type `A`.
    ///
    /// Yields `None` when `bind_index` is past the end of the attribute list or when
    /// the downcast to `A` does not succeed.
    pub fn attr<A: Attribute>(&self, bind_index: u32) -> Option<&A> {
        let slot = self.attributes.get(bind_index as usize)?;
        slot.downcast()
    }

    /// Mutable counterpart of [`Material::attr`].
    ///
    /// Yields `None` when `bind_index` is past the end of the attribute list or when
    /// the downcast to `A` does not succeed.
    pub fn attr_mut<A: Attribute>(&mut self, bind_index: u32) -> Option<&mut A> {
        let slot = self.attributes.get_mut(bind_index as usize)?;
        slot.downcast_mut()
    }

    /// Calls [`Attribute::prepare`] on every attribute, in binding order.
    pub fn prepare(&mut self, ctx: &Context) {
        for attribute in &mut self.attributes {
            attribute.prepare(ctx);
        }
    }

    /// The bind group layout this material was built with.
    pub fn layout(&self) -> &wgpu::BindGroupLayout {
        self.layout.as_ref()
    }

    /// The bind group holding this material's resources.
    pub fn group(&self) -> &wgpu::BindGroup {
        &self.group
    }
}

/// Collects attributes, in binding order, before a [`Material`] is built from them.
pub struct MaterialBuilder {
    /// Each attribute paired with the shader stages it is visible to; the position
    /// in this list becomes the binding index.
    attributes: Vec<(Box<dyn Attribute>, wgpu::ShaderStages)>,
}

impl MaterialBuilder {
    /// Appends a [`Uniform`] attribute constructed from `data`.
    ///
    /// Shorthand for `self.attr(Uniform::new(ctx, data), visibility)`.
    pub fn uniform<T: 'static + AsBytes>(
        self,
        ctx: &Context,
        data: T,
        visibility: wgpu::ShaderStages,
    ) -> Self {
        self.attr(Uniform::new(ctx, data), visibility)
    }

    /// Misspelled alias of [`MaterialBuilder::uniform`].
    ///
    /// Kept only so existing callers keep compiling; new code should call `uniform`.
    pub fn unifrom<T: 'static + AsBytes>(
        self,
        ctx: &Context,
        data: T,
        visibility: wgpu::ShaderStages,
    ) -> Self {
        self.uniform(ctx, data, visibility)
    }

    /// Appends `attr` as the next binding, visible to the given shader stages.
    ///
    /// The binding index is the number of attributes added before this one.
    pub fn attr(mut self, attr: impl Attribute, visibility: wgpu::ShaderStages) -> Self {
        self.attributes.push((Box::new(attr), visibility));
        self
    }

    /// Builds the material against a caller-supplied bind group layout.
    ///
    /// One bind group entry is created per attribute, using the attribute's position
    /// as the binding index and [`Attribute::res`] as the resource. The visibility
    /// recorded for each attribute is not consulted here.
    ///
    /// NOTE(review): `layout` presumably has to match the attributes' binding types
    /// and order — confirm against the wgpu bind group validation rules.
    pub fn build_with_layout(self, ctx: &Context, layout: Rc<wgpu::BindGroupLayout>) -> Material {
        // Counting in `u32` directly avoids a truncating `usize as u32` cast.
        let entries: Vec<_> = (0u32..)
            .zip(&self.attributes)
            .map(|(binding, (attribute, _))| wgpu::BindGroupEntry {
                binding,
                resource: attribute.res(),
            })
            .collect();

        let group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &layout,
            entries: &entries,
        });

        Material {
            layout,
            group,
            attributes: self
                .attributes
                .into_iter()
                .map(|(attribute, _)| attribute)
                .collect(),
        }
    }

    /// Builds the material, deriving a fresh bind group layout from the attributes.
    ///
    /// Each attribute contributes one layout entry whose binding index is its
    /// position, whose type comes from [`Attribute::ty`], and whose visibility is the
    /// one given when the attribute was added. The layout is then handed to
    /// [`MaterialBuilder::build_with_layout`].
    pub fn build(self, ctx: &Context) -> Material {
        let layout_entries: Vec<_> = (0u32..)
            .zip(&self.attributes)
            .map(|(binding, (attribute, visibility))| wgpu::BindGroupLayoutEntry {
                binding,
                visibility: *visibility,
                ty: attribute.ty(),
                count: None,
            })
            .collect();

        let layout = ctx.create_bind_group_layout(&layout_entries);

        self.build_with_layout(ctx, Rc::new(layout))
    }
}
